import argparse
import paddle
from paddle import nn
import numpy as np
import os
import pickle
import joblib
from data import getDataLoader
from build_vocab import Vocabulary
from model import EncoderCNN, DecoderRNN
from paddle.vision import transforms

import pdb

# While DEBUG is set, training is pinned to the CPU (see device selection below).
DEBUG = True
# Root of the local COCO 2017 copy; used to build the default CLI data paths.
COCO_PATH = "/home/dandelight/coco2017/coco2017"

# Device configuration: pick the GPU only if Paddle was built with CUDA
# support and we are not debugging; otherwise run on the CPU.
_use_gpu = paddle.is_compiled_with_cuda() and not DEBUG
device = paddle.set_device('gpu' if _use_gpu else 'cpu')



def main(args):
    # Create model directory
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    # Image preprocessing, normalization for the pretained resnet
    # TODO: PaddlePaddle Compactativity.
    transform = transforms.Compose([
        # transforms.RandomCrop(args.crop_size),
        # transforms.RandomHorizontalFlip(),
        # transforms.Normalize((0.485, 0.456, 0.406), 
        #                     (0.229, 0.224, 0.225)),
        #  transforms.ToTensor(),
    ])

    # Load vocabulary wrapper
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Build the data loader
    data_loader = getDataLoader(
        root = args.image_dir,
        json = args.caption_path,
        vocab = vocab,
        transform = transform,
        batch_size = args.batch_size,
        shuffle=True,
        #  num_workers=1
        num_workers=0
    )

    # Build the models
    encoder = EncoderCNN(args.embed_size) #.to(device)
    decoder = DecoderRNN(args.embed_size, args.hidden_size, len(vocab), args.num_layers) #.to(device)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    parameters = list(decoder.parameters()) + list(encoder.linear.parameters()) + list(encoder.bn.parameters())
    optimizer = paddle.optimizer.Adam(learning_rate=args.learning_rate, parameters=parameters)

    # Train the models
    total_step = len(data_loader)
    for epoch in range(args.num_epochs):
        pdb.set_trace()
        for i, (images, captions, lengths) in enumerate(data_loader):
            # Set mini-batch dataset
            # targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]
            # 谁能告诉我有什么用? (官方文档)

            # Forward, backward and optimize
            features = encoder(images)
            outputs = decoder(features, captions, lengths)
            loss = criterion(outputs, targets)
            decoder.zero_grad()
            encoder.zero_grad()
            optimizer.minimize()
            optimizer.clear_grad()
            loss.backward()
            optimizer.step()

            # Logging
            # TODO: 后续加入VisualDL
            # Print log info
            if i % args.log_step == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'
                      .format(epoch, args.num_epochs, i, total_step, loss.item(), np.exp(loss.item())))
            
            # Save the model checkpoints
            
            if (i+1) % args.save_step == 0:
                paddle.save(decoder.state_dict(), os.path.join(
                    args.model_path, 'decoder-{}-{}.ckpt'.format(epoch+1, i+1)))
                paddle.save(encoder.state_dict(), os.path.join(
                    args.model_path, 'encoder-{}-{}.ckpt'.format(epoch+1, i+1)))


if __name__ == '__main__':
    # Command-line entry point: parse the training configuration, echo it,
    # and start training.
    parser = argparse.ArgumentParser(
        description='Train the CNN encoder / RNN decoder image-captioning model.')

    # Paths and logging / checkpoint cadence.
    parser.add_argument('--model_path', type=str, default='models/',
                        help='path for saving trained models')
    # NOTE(review): currently unused because RandomCrop is commented out in
    # main()'s transform pipeline.
    parser.add_argument('--crop_size', type=int, default=224,
                        help='size for randomly cropping images')
    parser.add_argument('--vocab_path', type=str, default='{}/vocab.pkl'.format(COCO_PATH),
                        help='path for vocabulary wrapper')
    parser.add_argument('--image_dir', type=str, default='{}/resized2017'.format(COCO_PATH),
                        help='directory for resized images')
    parser.add_argument('--caption_path', type=str,
                        default='{}/captions_train2017.json'.format(COCO_PATH),
                        help='path for train annotation json file')
    parser.add_argument('--log_step', type=int, default=10,
                        help='step size for printing log info')
    parser.add_argument('--save_step', type=int, default=1000,
                        help='step size for saving trained models')

    # Model parameters.
    parser.add_argument('--embed_size', type=int, default=256,
                        help='dimension of word embedding vectors')
    parser.add_argument('--hidden_size', type=int, default=512,
                        help='dimension of lstm hidden states')
    parser.add_argument('--num_layers', type=int, default=1,
                        help='number of layers in lstm')

    # Optimization / data-loading parameters.
    parser.add_argument('--num_epochs', type=int, default=5)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--num_workers', type=int, default=2)
    parser.add_argument('--learning_rate', type=float, default=0.001)

    args = parser.parse_args()
    print(args)
    main(args)
